## 29.3.1 代码生成概述
代码生成模块是编程 Agent 的核心能力之一,它能够根据自然语言描述生成高质量的代码。代码生成涉及需求理解、架构设计、代码实现等多个环节。
### 代码生成流程
用户需求 ↓ 需求分析与理解 ↓ 架构设计 ↓ 代码实现 ↓ 代码验证 ↓ 优化与改进 ↓ 最终代码
## 29.3.2 需求分析
### 需求提取器
# Requirement extractor


class RequirementExtractor:
    """Turn a free-form user request into a structured ``Requirement``.

    The request is embedded in a prompt, sent to the injected LLM client,
    and the reply is parsed as a JSON object.
    """

    def __init__(self, llm_client: "LLMClient") -> None:
        """Store the client used for completions.

        Args:
            llm_client: Object exposing an awaitable ``complete(prompt)``.
        """
        self.llm_client = llm_client

    async def extract(self, user_request: str) -> "Requirement":
        """Ask the LLM to analyse ``user_request`` and parse its reply.

        Args:
            user_request: Natural-language description of what to build.

        Returns:
            The ``Requirement`` built from the LLM's JSON reply.

        Raises:
            ValueError: If the reply is not a JSON object.
        """
        # The prompt text is runtime data sent to the model; it is kept
        # exactly as authored (adjacent literals are concatenated).
        prompt = (
            " 分析用户需求,提取关键信息: "
            f"用户需求:{user_request} "
            "请提取以下信息: "
            "1. 功能需求(需要实现什么功能) "
            "2. 技术栈(使用的编程语言、框架等) "
            "3. 约束条件(性能、安全、兼容性等) "
            "4. 输入输出(预期的输入和输出) "
            "5. 特殊要求(代码风格、注释要求等) "
            "以 JSON 格式返回结果。 "
        )
        response = await self.llm_client.complete(prompt)
        return self._parse_requirement(response)

    def _parse_requirement(self, response: str) -> "Requirement":
        """Decode the LLM reply into a ``Requirement``.

        Missing keys fall back to empty containers.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            A ``Requirement`` populated from the decoded JSON object.

        Raises:
            ValueError: If ``response`` is not valid JSON, or is valid JSON
                whose top-level value is not an object.
        """
        try:
            data = json.loads(response)
        except json.JSONDecodeError as err:
            # Keep the decoder's position info reachable via __cause__.
            raise ValueError("Invalid requirement format") from err

        # A bare list/string/number is valid JSON but has no ``.get``;
        # report it as the same format error instead of an AttributeError.
        if not isinstance(data, dict):
            raise ValueError("Invalid requirement format")

        return Requirement(
            functional_requirements=data.get('functional_requirements', []),
            tech_stack=data.get('tech_stack', {}),
            constraints=data.get('constraints', {}),
            inputs=data.get('inputs', []),
            outputs=data.get('outputs', []),
            special_requirements=data.get('special_requirements', {}),
        )


# Requirement validator


class RequirementValidator:
    """Run basic completeness checks on an extracted ``Requirement``."""

    def validate(self, requirement: "Requirement") -> "ValidationResult":
        """Collect problems found in ``requirement``.

        Checks performed:
            * at least one functional requirement is present;
            * a tech stack is present;
            * if a ``'performance'`` constraint exists, it is a dict that
              contains ``'max_time'``.

        Args:
            requirement: The requirement to inspect.

        Returns:
            A ``ValidationResult`` whose ``valid`` flag is true only when
            no issue was recorded, together with the issue messages.
        """
        # NOTE(review): CodeValidator.validate in this document constructs
        # ValidationResult with ``passed=``/``results=`` rather than
        # ``valid=`` -- confirm which signature the real type has.
        issues = []

        if not requirement.functional_requirements:
            issues.append("No functional requirements specified")

        if not requirement.tech_stack:
            issues.append("No tech stack specified")

        if 'performance' in requirement.constraints:
            perf = requirement.constraints['performance']
            if not isinstance(perf, dict) or 'max_time' not in perf:
                issues.append("Invalid performance constraint")

        return ValidationResult(valid=not issues, issues=issues)
## 29.3.3 架构设计
### 架构设计器
# Architecture designer


class ArchitectureDesigner:
    """Ask an LLM to propose a software architecture for a requirement."""

    def __init__(self, llm_client: "LLMClient") -> None:
        """Store the LLM client and build the design-pattern catalogue.

        Args:
            llm_client: Object exposing an awaitable ``complete(prompt)``.
        """
        self.llm_client = llm_client
        self.design_patterns = self._load_design_patterns()

    async def design(self, requirement: "Requirement") -> "Architecture":
        """Request an architecture for ``requirement`` and parse the reply.

        Args:
            requirement: Source of the functional requirements, tech stack
                and constraints interpolated into the prompt.

        Returns:
            The ``Architecture`` built from the LLM's JSON reply.

        Raises:
            ValueError: If the reply is not a JSON object.
        """
        # Prompt text is runtime data and is kept exactly as authored.
        prompt = (
            " 根据需求设计软件架构: "
            f"功能需求:{requirement.functional_requirements} "
            f"技术栈:{requirement.tech_stack} "
            f"约束条件:{requirement.constraints} "
            "请设计: "
            "1. 系统架构(模块划分、层次结构) "
            "2. 类设计(类、接口、继承关系) "
            "3. 数据结构(数据模型、存储方案) "
            "4. 接口设计(API、函数签名) "
            "5. 设计模式(适用的设计模式) "
            "以 JSON 格式返回架构设计。 "
        )
        response = await self.llm_client.complete(prompt)
        return self._parse_architecture(response)

    def _parse_architecture(self, response: str) -> "Architecture":
        """Decode the LLM reply into an ``Architecture``.

        Missing keys fall back to empty containers.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            An ``Architecture`` populated from the decoded JSON object.

        Raises:
            ValueError: If ``response`` is not valid JSON, or is valid JSON
                whose top-level value is not an object.
        """
        try:
            data = json.loads(response)
        except json.JSONDecodeError as err:
            # Keep the decoder's position info reachable via __cause__.
            raise ValueError("Invalid architecture format") from err

        # Valid JSON that is not an object has no ``.get``; report it as a
        # format error rather than letting an AttributeError escape.
        if not isinstance(data, dict):
            raise ValueError("Invalid architecture format")

        return Architecture(
            system_architecture=data.get('system_architecture', {}),
            class_design=data.get('class_design', []),
            data_structures=data.get('data_structures', []),
            interfaces=data.get('interfaces', []),
            design_patterns=data.get('design_patterns', []),
        )

    def _load_design_patterns(self) -> Dict[str, "DesignPattern"]:
        """Return the built-in pattern catalogue keyed by lowercase name."""
        # NOTE(review): the third keyword is a non-ASCII identifier that
        # must match a field of DesignPattern, which is declared outside
        # this excerpt -- left untouched.
        return {
            'singleton': DesignPattern(
                name='Singleton',
                description='确保一个类只有一个实例',
                适用场景='需要全局唯一访问点',
            ),
            'factory': DesignPattern(
                name='Factory',
                description='创建对象的接口',
                适用场景='需要灵活创建对象',
            ),
            'observer': DesignPattern(
                name='Observer',
                description='定义对象间的一对多依赖',
                适用场景='需要事件通知机制',
            ),
        }


# Architecture evaluator


class ArchitectureEvaluator:
    """Score an ``Architecture`` on four heuristic dimensions."""

    def evaluate(self, architecture: "Architecture",
                 requirement: "Requirement") -> "EvaluationResult":
        """Score ``architecture`` and derive recommendations.

        Args:
            architecture: The design to score.
            requirement: Consulted only for its performance constraints.

        Returns:
            An ``EvaluationResult`` holding the per-dimension scores, their
            unweighted mean, and the recommendations derived from them.
        """
        scores = {
            'modularity': self._evaluate_modularity(architecture),
            'extensibility': self._evaluate_extensibility(architecture),
            'performance': self._evaluate_performance(
                architecture, requirement
            ),
            'maintainability': self._evaluate_maintainability(architecture),
        }

        # Unweighted mean of the four dimension scores.
        total_score = sum(scores.values()) / len(scores)

        return EvaluationResult(
            total_score=total_score,
            scores=scores,
            recommendations=self._generate_recommendations(scores),
        )

    def _evaluate_modularity(self, architecture: "Architecture") -> float:
        """Score modularity from the number of declared modules.

        Returns:
            ``0.0`` when no modules are listed, otherwise
            ``len(modules) / 10`` capped at ``1.0``.
        """
        modules = architecture.system_architecture.get('modules', [])
        if not modules:
            return 0.0

        # More modules score higher; ten or more reach the cap.
        return min(len(modules) / 10.0, 1.0)

    def _evaluate_extensibility(self, architecture: "Architecture") -> float:
        """Score extensibility from the number of design patterns used.

        Returns:
            ``0.5`` when no patterns are listed, otherwise ``0.5`` plus
            ``len(patterns) / 5`` with the bonus capped at ``0.5``.
        """
        patterns = architecture.design_patterns
        if not patterns:
            return 0.5

        # Each pattern adds 0.1 on top of the 0.5 baseline.
        return 0.5 + min(len(patterns) / 5.0, 0.5)

    def _evaluate_performance(self, architecture: "Architecture",
                              requirement: "Requirement") -> float:
        """Score performance readiness.

        Only the presence of a performance constraint is checked, not its
        values.

        Returns:
            ``0.8`` when the requirement has no performance constraint.
            Otherwise ``0.8`` plus ``0.1`` for each of the ``'caching'``
            and ``'concurrency'`` keys found in the system architecture,
            capped at ``1.0``.
        """
        constraints = requirement.constraints.get('performance', {})
        if not constraints:
            return 0.8  # default score

        score = 0.8  # base score

        # Caching strategy declared.
        if 'caching' in architecture.system_architecture:
            score += 0.1

        # Concurrency handling declared.
        if 'concurrency' in architecture.system_architecture:
            score += 0.1

        return min(score, 1.0)

    def _evaluate_maintainability(self, architecture: "Architecture") -> float:
        """Score maintainability from the average methods per class.

        Returns:
            ``0.5`` when no classes are designed; ``1.0`` when the average
            is within 5..15; ``0.8`` below 5; ``0.6`` above 15.
        """
        classes = architecture.class_design
        if not classes:
            return 0.5

        method_total = sum(len(c.get('methods', [])) for c in classes)
        avg_methods = method_total / len(classes)

        # A moderate method count is treated as the most maintainable.
        if avg_methods < 5:
            return 0.8
        if avg_methods <= 15:
            return 1.0
        return 0.6

    def _generate_recommendations(self,
                                  scores: Dict[str, float]) -> List[str]:
        """Return advice for each dimension scoring below ``0.7``.

        Only modularity, extensibility and maintainability are inspected;
        the performance score produces no recommendation.

        Args:
            scores: Mapping that must contain ``'modularity'``,
                ``'extensibility'`` and ``'maintainability'``.

        Returns:
            Recommendation messages, in the order listed above.
        """
        advice = (
            ('modularity', "建议增加模块划分,提高模块化程度"),
            ('extensibility', "建议使用更多设计模式,提高可扩展性"),
            ('maintainability', "建议简化类设计,降低复杂度"),
        )
        return [message for key, message in advice if scores[key] < 0.7]
# Section 29.3.4 -- code implementation: code generator


class CodeGenerator:
    """Generate source code for an ``Architecture`` through an LLM."""

    def __init__(self, llm_client: "LLMClient") -> None:
        """Store the LLM client and load the code templates.

        Args:
            llm_client: Object exposing an awaitable ``complete(prompt)``.
        """
        self.llm_client = llm_client
        self.code_templates = self._load_code_templates()

    def _load_code_templates(self) -> Dict[str, str]:
        """Return the code templates available to the generator.

        ``__init__`` calls this method, but the class did not define it,
        so constructing a ``CodeGenerator`` raised ``AttributeError``.
        Nothing in this class reads ``code_templates``, so an empty
        mapping is returned.
        """
        return {}

    async def generate(self, architecture: "Architecture",
                       requirement: "Requirement") -> "GeneratedCode":
        """Generate class, interface and main-program code, then merge them.

        The LLM is called once per class, once per interface and once for
        the main program, one after another.

        Args:
            architecture: Supplies ``class_design`` and ``interfaces``.
            requirement: Supplies the target language and the functional
                requirements used in the prompts.

        Returns:
            A ``GeneratedCode`` holding the combined text and its parts.
        """
        class_codes = [
            await self._generate_class_code(class_design, requirement)
            for class_design in architecture.class_design
        ]
        interface_codes = [
            await self._generate_interface_code(interface, requirement)
            for interface in architecture.interfaces
        ]
        main_code = await self._generate_main_code(architecture, requirement)

        full_code = self._combine_codes(
            class_codes, interface_codes, main_code
        )

        return GeneratedCode(
            full_code=full_code,
            class_codes=class_codes,
            interface_codes=interface_codes,
            main_code=main_code,
        )

    async def _generate_class_code(self, class_design: Dict,
                                   requirement: "Requirement") -> str:
        """Ask the LLM for the implementation of one designed class.

        Args:
            class_design: Dict read for ``name``, ``methods``,
                ``attributes`` and ``parent``.
            requirement: Its tech stack's ``language`` (default
                ``'Python'``) selects the target language.

        Returns:
            The LLM completion, unmodified.
        """
        # Prompt text is runtime data and is kept exactly as authored.
        prompt = (
            " 根据类设计生成代码: "
            f"类名:{class_design.get('name')} "
            f"方法:{class_design.get('methods', [])} "
            f"属性:{class_design.get('attributes', [])} "
            f"父类:{class_design.get('parent', 'None')} "
            f"编程语言:{requirement.tech_stack.get('language', 'Python')} "
            "请生成完整的类代码,包括: "
            "1. 类定义 "
            "2. 所有方法的实现 "
            "3. 必要的注释 "
            "4. 错误处理 "
        )
        return await self.llm_client.complete(prompt)

    async def _generate_interface_code(self, interface: Dict,
                                       requirement: "Requirement") -> str:
        """Ask the LLM for the code of one designed interface.

        Args:
            interface: Dict read for ``name`` and ``methods``.
            requirement: Its tech stack's ``language`` (default
                ``'Python'``) selects the target language.

        Returns:
            The LLM completion, unmodified.
        """
        prompt = (
            " 根据接口设计生成代码: "
            f"接口名:{interface.get('name')} "
            f"方法:{interface.get('methods', [])} "
            f"编程语言:{requirement.tech_stack.get('language', 'Python')} "
            "请生成完整的接口代码。 "
        )
        return await self.llm_client.complete(prompt)

    async def _generate_main_code(self, architecture: "Architecture",
                                  requirement: "Requirement") -> str:
        """Ask the LLM for a main program wiring the classes together.

        Args:
            architecture: Names of its classes and interfaces are listed
                in the prompt.
            requirement: Supplies functional requirements and language.

        Returns:
            The LLM completion, unmodified.
        """
        prompt = (
            " 根据架构和需求生成主程序代码: "
            f"功能需求:{requirement.functional_requirements} "
            f"类:{[c.get('name') for c in architecture.class_design]} "
            f"接口:{[i.get('name') for i in architecture.interfaces]} "
            f"编程语言:{requirement.tech_stack.get('language', 'Python')} "
            "请生成主程序代码,包括: "
            "1. 初始化代码 "
            "2. 主要业务逻辑 "
            "3. 示例用法 "
        )
        return await self.llm_client.complete(prompt)

    def _combine_codes(self, class_codes: List[str],
                       interface_codes: List[str],
                       main_code: str) -> str:
        """Join the generated parts into one newline-separated text.

        Layout: a header comment, then interfaces, then classes, then the
        main program. The interface and class sections are omitted when
        their list is empty; each present section ends with a blank line.

        Returns:
            The combined source text.
        """
        combined = ["# Generated Code", ""]

        if interface_codes:
            combined += ["# Interfaces", *interface_codes, ""]

        if class_codes:
            combined += ["# Classes", *class_codes, ""]

        combined += ["# Main Program", main_code]
        return "\n".join(combined)


# Code optimizer


class CodeOptimizer:
    """Analyse generated code with an LLM and apply its suggestions."""

    def __init__(self, llm_client: "LLMClient") -> None:
        """Store the LLM client.

        Args:
            llm_client: Object exposing an awaitable ``complete(prompt)``.
        """
        self.llm_client = llm_client

    async def optimize(self, code: str,
                       requirement: "Requirement") -> "OptimizedCode":
        """Find issues in ``code``, derive suggestions and apply them.

        Args:
            code: Source text to optimize.
            requirement: Its constraints are passed to the suggestion
                prompt.

        Returns:
            An ``OptimizedCode`` with the original text, the rewritten
            text, the issues and the suggestions.
        """
        issues = await self._analyze_issues(code)
        suggestions = await self._generate_suggestions(
            code, issues, requirement
        )
        optimized_code = await self._apply_optimizations(code, suggestions)

        return OptimizedCode(
            original_code=code,
            optimized_code=optimized_code,
            issues=issues,
            suggestions=suggestions,
        )

    async def _analyze_issues(self, code: str) -> List["CodeIssue"]:
        """Ask the LLM to list problems in ``code`` and parse the reply."""
        # Prompt text is runtime data and is kept exactly as authored.
        prompt = (
            " 分析以下代码的问题: "
            f"{code} "
            "请识别: "
            "1. 性能问题 "
            "2. 安全问题 "
            "3. 代码风格问题 "
            "4. 潜在的 bug "
            "5. 可维护性问题 "
            "以 JSON 格式返回问题列表。 "
        )
        response = await self.llm_client.complete(prompt)
        # NOTE(review): _parse_issues is not defined anywhere in the
        # visible class; this call fails with AttributeError unless it is
        # supplied elsewhere. The CodeIssue constructor is not visible
        # either, so no implementation is guessed here.
        return self._parse_issues(response)

    async def _generate_suggestions(
        self, code: str, issues: List["CodeIssue"],
        requirement: "Requirement",
    ) -> List["Suggestion"]:
        """Ask the LLM for fixes to ``issues`` and parse the reply."""
        prompt = (
            " 基于代码问题生成优化建议: "
            f"代码:{code} "
            f"问题:{issues} "
            f"约束条件:{requirement.constraints} "
            "请生成具体的优化建议,包括: "
            "1. 问题描述 "
            "2. 优化方案 "
            "3. 预期效果 "
            "以 JSON 格式返回建议列表。 "
        )
        response = await self.llm_client.complete(prompt)
        # NOTE(review): _parse_suggestions is likewise undefined in the
        # visible class -- confirm where it is meant to come from.
        return self._parse_suggestions(response)

    async def _apply_optimizations(self, code: str,
                                   suggestions: List["Suggestion"]) -> str:
        """Apply every applicable suggestion in order.

        Each suggestion is applied to the output of the previous one;
        suggestions whose ``applicable`` attribute is falsy are skipped.

        Returns:
            The code after the last applied suggestion, or ``code``
            unchanged when none applied.
        """
        optimized_code = code
        for suggestion in suggestions:
            if not suggestion.applicable:
                continue
            optimized_code = await self._apply_suggestion(
                optimized_code, suggestion
            )
        return optimized_code

    async def _apply_suggestion(self, code: str,
                                suggestion: "Suggestion") -> str:
        """Ask the LLM to rewrite ``code`` according to one suggestion.

        Returns:
            The LLM completion, unmodified.
        """
        prompt = (
            " 应用以下优化建议到代码: "
            f"原始代码:{code} "
            f"优化建议:{suggestion.description} "
            f"优化方案:{suggestion.solution} "
            "请返回优化后的代码。 "
        )
        return await self.llm_client.complete(prompt)
## 29.3.5 代码验证
### 代码验证器
# Code validator


class CodeValidator:
    """Run syntax, type, logic and performance checks on generated code."""

    def __init__(self, tool_manager: "ToolManager") -> None:
        """Store the tool manager used to look up the type checker.

        Args:
            tool_manager: Object exposing ``get_tool(name)``.
        """
        self.tool_manager = tool_manager

    async def validate(self, code: str,
                       requirement: "Requirement") -> "ValidationResult":
        """Run all four checks and aggregate their outcomes.

        Args:
            code: Source text to validate.
            requirement: Its tech stack's ``language`` decides whether the
                Python syntax check runs.

        Returns:
            A ``ValidationResult`` that passes only if every check passed,
            with the individual results and the failing checks' messages.
        """
        # NOTE(review): RequirementValidator.validate in this document
        # constructs ValidationResult with ``valid=`` rather than
        # ``passed=``/``results=`` -- confirm the real signature.
        results = [
            await self._check_syntax(code, requirement),
            await self._check_types(code, requirement),
            await self._check_logic(code, requirement),
            await self._check_performance(code, requirement),
        ]

        return ValidationResult(
            passed=all(r.passed for r in results),
            results=results,
            issues=self._collect_issues(results),
        )

    @staticmethod
    def _normalize_language(language: object) -> str:
        """Return ``language`` lower-cased and stripped, for comparison."""
        return str(language).strip().lower()

    async def _check_syntax(self, code: str,
                            requirement: "Requirement") -> "CheckResult":
        """Check syntax for the requirement's language.

        Only Python is implemented; any other language yields a passing
        result whose message says the check is not implemented. The
        language name is compared case-insensitively: ``CodeGenerator``
        defaults the same field to ``'Python'``, which the previous exact
        ``== 'python'`` comparison treated as an unsupported language.

        Returns:
            A ``CheckResult`` of type ``'syntax'``.
        """
        language = requirement.tech_stack.get('language', 'python')

        try:
            if self._normalize_language(language) == 'python':
                return await self._check_python_syntax(code)
            return CheckResult(
                check_type='syntax',
                passed=True,
                message=f"Syntax check for {language} not implemented"
            )
        except Exception as e:
            # _check_python_syntax handles SyntaxError itself; anything
            # else raised while checking is reported as a failed check.
            return CheckResult(
                check_type='syntax',
                passed=False,
                message=f"Syntax error: {str(e)}"
            )

    async def _check_python_syntax(self, code: str) -> "CheckResult":
        """Compile ``code`` (without running it) to detect syntax errors.

        Returns:
            A passing ``CheckResult`` if ``compile`` succeeds, otherwise a
            failing one carrying the error's line number and message.
        """
        try:
            compile(code, '<string>', 'exec')
        except SyntaxError as e:
            return CheckResult(
                check_type='syntax',
                passed=False,
                message=f"Syntax error at line {e.lineno}: {e.msg}"
            )
        return CheckResult(
            check_type='syntax',
            passed=True,
            message="Syntax is valid"
        )

    async def _check_types(self, code: str,
                           requirement: "Requirement") -> "CheckResult":
        """Run the ``'type_checker'`` tool on ``code`` when available.

        ``requirement`` is accepted for signature parity with the other
        checks but is not used.

        Returns:
            A passing ``CheckResult`` when no type checker is registered
            or the tool reports success; a failing one when the tool
            reports failure or raises.
        """
        tool = self.tool_manager.get_tool('type_checker')
        if not tool:
            # A missing tool is treated as a pass, not a failure.
            return CheckResult(
                check_type='type',
                passed=True,
                message="Type checker not available"
            )

        try:
            result = await tool.execute({'code': code})
            if result.success:
                return CheckResult(
                    check_type='type',
                    passed=True,
                    message="Type check passed"
                )
            return CheckResult(
                check_type='type',
                passed=False,
                message=f"Type check failed: {result.error}"
            )
        except Exception as e:
            return CheckResult(
                check_type='type',
                passed=False,
                message=f"Type check error: {str(e)}"
            )

    async def _check_logic(self, code: str,
                           requirement: "Requirement") -> "CheckResult":
        """Flag likely logic problems using substring heuristics.

        These are plain text searches over the whole source, not code
        analysis: ``'None'`` without any ``'if'`` is reported as an
        unchecked None reference, and ``'open('`` without ``'close('`` is
        reported as a resource leak (so a ``with open(...)`` block is
        flagged too). ``requirement`` is not used.

        Returns:
            A ``CheckResult`` of type ``'logic'``; failing when at least
            one heuristic matched.
        """
        issues = []

        # Possible None dereference.
        if 'None' in code and 'if' not in code:
            issues.append("Potential None reference without check")

        # Possible unclosed file.
        if 'open(' in code and 'close(' not in code:
            issues.append("Potential resource leak (file not closed)")

        if not issues:
            return CheckResult(
                check_type='logic',
                passed=True,
                message="Logic check passed"
            )
        return CheckResult(
            check_type='logic',
            passed=False,
            message=f"Logic issues: {', '.join(issues)}"
        )

    async def _check_performance(self, code: str,
                                 requirement: "Requirement") -> "CheckResult":
        """Flag likely performance problems using substring heuristics.

        The loop check counts occurrences of ``'for '`` in the text and
        reports when there are more than two; it does not measure nesting
        depth. The second check reports when both ``'list('`` and
        ``'range('`` occur anywhere. ``requirement`` is not used.

        Returns:
            A ``CheckResult`` of type ``'performance'``; failing when at
            least one heuristic matched.
        """
        issues = []

        # Many loops in the text.
        if code.count('for ') > 2:
            issues.append("Deep nested loops may cause performance issues")

        # Materialized ranges.
        if 'list(' in code and 'range(' in code:
            issues.append("Consider using generator expressions for large ranges")

        if not issues:
            return CheckResult(
                check_type='performance',
                passed=True,
                message="Performance check passed"
            )
        return CheckResult(
            check_type='performance',
            passed=False,
            message=f"Performance issues: {', '.join(issues)}"
        )

    def _collect_issues(self, results: List["CheckResult"]) -> List[str]:
        """Return the messages of all failed checks, in order."""
        return [result.message for result in results if not result.passed]
通过实现这些组件,我们可以构建一个完整的代码生成模块,实现从需求分析到代码验证的全流程自动化。